import torch.nn as nn
import torch.nn.functional as F
import torch
from torch_geometric.data import Data, Batch
from torch.autograd import Variable
from skimage.segmentation import slic, mark_boundaries
from skimage import io
import cv2
import sys
import numpy as np
from skimage import segmentation
import matplotlib.pyplot as plt
import os
from collections import defaultdict
# True when PyTorch reports an available CUDA device; evaluated once at import time.
use_cuda = torch.cuda.is_available()


def rotate(image, angle):
    """
    Quarter-Turn Rotation
    Rotate an array in steps of 90 degrees using numpy.rot90.


    Keyword arguments:
    image --ndarray to rotate (at least 2-dimensional)
    angle --number of 90-degree turns to apply (numpy.rot90's k),
            NOT an angle in degrees
    """
    quarter_turns = angle
    return np.rot90(image, k=quarter_turns)

def rank(l):
    """
    Rank/Frequency Counter
    Count how often each value occurs and list the values from the most
    -frequent to the least frequent.
    e.g. rank([3,5,2,2,2,2,1,1]) -> [(2,4), (1,2), (3,1), (5,1)]

    Values with the same count keep the order in which they were first seen.


    Keyword arguments:
    l  --a 1-dimensional numpy array (any iterable of hashable values works)

    Returns a list of (value, count) tuples sorted by descending count.
    """
    counts = {}
    for value in l:
        counts[value] = counts.get(value, 0) + 1
    # sorted() is stable, so ties stay in first-seen order.
    return sorted(counts.items(), key=lambda pair: pair[1], reverse=True)


def adjacency(sMask):
    """
    Adjacency Matrix
    This function finds bidirectional adjacency edges
    -from a SLIC superpixel mask.

    The mask is scanned one line at a time: for every index y along the
    second axis, the labels sMask[0, y], sMask[1, y], ... are walked in order
    and every change of label between consecutive elements is recorded as an
    edge. Each undirected edge is reported once per direction, as the two
    consecutive entries (a, b) and (b, a), in order of first discovery.

    NOTE: only label changes between consecutive elements along the first
    axis are detected; superpixels that touch exclusively along the second
    axis are not reported.


    Keyword arguments:
    sMask --a 2-dimensional numpy array generated by SLIC algorithm

    Returns a list of (label, label) tuples.
    """
    line_length, line_count = sMask.shape[0], sMask.shape[1]
    adj = []
    if line_length == 0:
        return adj
    seen = set()  # O(1) duplicate check instead of searching the result list
    for y in range(line_count):
        # Seed with the real first label of the line so that no artificial
        # edge to label 0 is produced.
        curr = sMask[0, y]
        for x in range(1, line_length):  # includes the last element
            target = sMask[x, y]
            if target == curr:
                continue
            if (curr, target) not in seen:
                seen.add((curr, target))
                seen.add((target, curr))
                adj.append((curr, target))
                adj.append((target, curr))  # bidirectional
            curr = target
    return adj


def adjacency3(sMask, mask_length):
    """
    Adjacency Matrix
    This function efficiently finds bidirectional adjacency edges
    -from a SLIC superpixel mask at the cost of higher RAM/cache usage
    -(a mask_length x mask_length boolean lookup table).

    The mask is scanned one line at a time: for every index y along the
    second axis, the labels sMask[0, y], sMask[1, y], ... are walked in order
    and every change of label between consecutive elements is recorded as an
    edge. Each undirected edge is reported once per direction, as the two
    consecutive entries (a, b) and (b, a), in order of first discovery.

    NOTE: only label changes between consecutive elements along the first
    axis are detected; superpixels that touch exclusively along the second
    axis are not reported.


    Keyword arguments:
    sMask       --a 2-dimensional numpy array generated by SLIC algorithm;
                  labels are used as indices into the lookup table, so they
                  must be integers smaller than mask_length
    mask_length --number of unique SLIC superpixels in sMask

    Returns a list of (label, label) tuples.
    """
    line_length, line_count = sMask.shape[0], sMask.shape[1]
    adj = []
    if line_length == 0:
        return adj
    edge_visited = np.zeros((mask_length, mask_length), dtype=bool)
    for y in range(line_count):
        # Seed with the real first label of the line so that no artificial
        # edge to label 0 is produced.
        curr = sMask[0, y]
        for x in range(1, line_length):  # includes the last element
            target = sMask[x, y]
            if target == curr:
                continue
            # Both directions are always marked together, so one lookup
            # is enough to know whether the edge is new.
            if not edge_visited[curr, target]:
                edge_visited[curr, target] = True
                edge_visited[target, curr] = True
                adj.append((curr, target))
                adj.append((target, curr))  # bidirectional
            curr = target
    return adj


def center(pixel):
    """
    Center of Superpixels
    This function finds the median-based center of a SLIC superpixel tile:
    -the x and y coordinates are sorted independently and the element at
    -index len(pixel) // 2 of each is taken.


    Keyword arguments:
    pixel --non-empty list of (x,y) pixel coordinate tuples from a SLIC
            superpixel tile.

    Returns an (x, y) tuple of ints.
    """
    middle = len(pixel) // 2
    xs = sorted(point[0] for point in pixel)
    ys = sorted(point[1] for point in pixel)
    return (int(xs[middle]), int(ys[middle]))


def euclidean_dist(a, b):
    """
    Euclidean Distance
    Calculate the Euclidean distance between two matrices, i.e. the
    -default numpy.linalg.norm of their element-wise difference.


    Keyword arguments:
    a  --first matrix's ndarray
    b  --second matrix's ndarray, must match input a's shape
    """
    difference = a - b
    return np.linalg.norm(difference)


def mse(a, b):
    """
    Mean Squared Error
    Calculate the Mean Squared Error between two matrices: the mean, over
    -all elements, of the squared element-wise difference.


    Keyword arguments:
    a  --first matrix's ndarray
    b  --second matrix's ndarray, must match input a's shape
    """
    residual = a - b
    squared = residual * residual
    return squared.mean()


def fnorm(a, b):
    """
    Frobenius Norm Distance
    Calculate a Frobenius-norm based distance between two matrices:
    -sqrt(|sum(a^2) - sum(b^2)|), i.e. the square root of the absolute
    -difference of the two squared Frobenius norms.

    NOTE(review): this is not the Frobenius norm of (a - b); two different
    matrices with equal norms give 0. Confirm this is the intended measure.


    Keyword arguments:
    a  --first matrix's ndarray
    b  --second matrix's ndarray
    """
    energy_a = np.sum(np.square(a))  # sum of squares
    energy_b = np.sum(np.square(b))
    gap = np.absolute(energy_a - energy_b)
    return np.sqrt(gap)


def chisq_dist(a, b, gamma = 1):
    """
    Chi-Square Distance
    Calculate a Chi-Square style similarity between two matrices:
    -exp(-gamma * 0.5 * sum((a - b)^2) / sum(a + b)).
    Identical inputs give 1.0; the value decreases as the inputs diverge.

    Both sums run over the whole arrays before the division, and the
    result is undefined when sum(a + b) is 0.


    Keyword arguments:
    a      --first matrix's ndarray
    b      --second matrix's ndarray, must match input a's shape
    gamma  --scale applied to the distance inside the exponential
    """
    squared_gap = np.square(a - b).sum()
    total_mass = (a + b).sum()
    half_ratio = 0.5 * squared_gap / total_mass
    return np.exp(- gamma * half_ratio)


def _loss_summary(losses):
    """Render per-image losses for logging; images without a loss show as -1.0."""
    return [round(x.item(), 2) if x is not None else -1.00 for x in losses]


def one_label_loss(gt_percent, predict, moe, batch_node_num):
    """
    Proposed Loss Function
    Our proposed Loss Functions calculates cost of training batch using
    -GCN's output graphs and weak image level annotations.
    For more information, please refer to our paper.

    For every image the nodes of its graph are sliced out of predict, then:
      * the mean of the int(n * (gt - moe)) largest predictions is pulled
        towards 1 with a SmoothL1 loss ("targeted regions");
      * the mean of the int(n * (1 - gt - moe)) smallest predictions is
        pulled towards the background target with a SmoothL1 loss
        ("background regions").
    A term is skipped for an image when its node count rounds down to 0.
    The per-image terms are printed and then summed over the batch, weighted
    by (gt - moe) and (1 - gt - moe) respectively.


    Keyword arguments:
    gt_percent       --Ground-Truth percent, a weak image-level annotation
    predict          --GCN module output, gradient required; the graphs of
                       the batch are concatenated along dim 0
    moe              --Margin of Error, a weak image-level annotation
    batch_node_num   --integer list of node numbers per image in batch

    Returns (top_k_loss, bottom_k_loss); either entry is None when no image
    in the batch contributed to that term.
    """
    loss_fn = nn.SmoothL1Loss()
    # Build the targets on the device predict actually lives on instead of
    # relying on a global CUDA flag.
    device = predict.device
    foreground_target = torch.ones(1, dtype=torch.float, device=device)
    # NOTE(review): the background target has always been -1 for CUDA
    # tensors and 0 for CPU tensors. That asymmetry is kept unchanged here;
    # confirm which value is intended.
    if predict.is_cuda:
        background_target = -torch.ones(1, dtype=torch.float, device=device)
    else:
        background_target = torch.zeros(1, dtype=torch.float, device=device)

    curr_index = 0
    batch_top_k_loss = []
    batch_bottom_k_loss = []
    for i, one_gt_percent in enumerate(gt_percent):
        one_moe = moe[i]
        total_length = batch_node_num[i]  # one graph length
        predict_slice = torch.narrow(predict, 0, curr_index, total_length)
        curr_index += total_length

        # 100 * (0.8 - 0.1) = top 70 %
        threshold_ceil = max(int(total_length * (one_gt_percent - one_moe)), 0)
        # 100 * (1 - 0.8 - 0.1) = bottom 10 %
        threshold_floor = max(int(total_length * (1.0 - one_gt_percent - one_moe)), 0)

        top_k_loss = None
        if threshold_ceil > 0:
            top_k, _ = torch.topk(predict_slice, k=threshold_ceil, dim=0,
                                  largest=True, sorted=False)
            top_k_loss = loss_fn(torch.mean(top_k, dim=0), foreground_target)

        bottom_k_loss = None
        if threshold_floor > 0:
            bottom_k, _ = torch.topk(predict_slice, k=threshold_floor, dim=0,
                                     largest=False, sorted=False)
            bottom_k_loss = loss_fn(torch.mean(bottom_k, dim=0), background_target)

        batch_top_k_loss.append(top_k_loss)
        batch_bottom_k_loss.append(bottom_k_loss)

    print("-------------------------------------------------------------------------------")
    print("Targeted Regions Losses Per Image")
    print(_loss_summary(batch_top_k_loss))
    print("Background Regions Losses Per Image")
    print(_loss_summary(batch_bottom_k_loss))
    print("-------------------------------------------------------------------------------")

    top_k_loss = None
    bottom_k_loss = None
    for t, b, g, a in zip(batch_top_k_loss, batch_bottom_k_loss, gt_percent, moe):
        if t is not None:
            weighted_top = (g - a) * t
            top_k_loss = weighted_top if top_k_loss is None else top_k_loss + weighted_top
        if b is not None:
            weighted_bottom = (1.0 - g - a) * b
            bottom_k_loss = weighted_bottom if bottom_k_loss is None else bottom_k_loss + weighted_bottom
    return top_k_loss, bottom_k_loss




def plot_grad_flow(named_parameters):
    """
    Utility Function-Visualize Gradient Flow
    This utility function can assist in checking gradient flow between layers.
    It plots the mean absolute gradient of every trainable non-bias
    -parameter and shows the figure. Call it after backward().

    Parameters whose .grad is still None (no gradient has been accumulated
    for them) are left out of the plot instead of raising.


    Keyword arguments:
    named_parameters   --iterable of (name, parameter) pairs, e.g. a
                         module's named_parameters()
    """
    ave_grads = []
    layers = []
    for name, param in named_parameters:
        if not param.requires_grad or "bias" in name:
            continue
        if param.grad is None:
            continue
        layers.append(name)
        # .item() yields a plain Python float, so matplotlib never receives
        # a tensor (which it cannot convert when the tensor is on the GPU).
        ave_grads.append(param.grad.abs().mean().item())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(xmin=0, xmax=len(ave_grads))
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.show()

def fuse_results(multi_args):
    """
    Fuse GCN Outputs into a Superpixel Mask
    Replace every superpixel label in the mask by the GCN output of that
    -superpixel. The mask is modified in place and also returned.

    Labels are matched against a snapshot of the original mask, so a value
    that was just written can never be mistaken for a later label (e.g.
    outputs [1, 0] on labels {0, 1}). Labels outside
    range(len(outgcn_numpy)) are left unchanged.


    Keyword arguments:
    multi_args --tuple (outgcn_numpy, segments_copy):
                 outgcn_numpy  --per-superpixel outputs, indexed by label
                 segments_copy --ndarray of superpixel labels; overwritten
                                 with the outputs (values are cast to its
                                 dtype)
    """
    (outgcn_numpy, segments_copy) = multi_args
    original_labels = segments_copy.copy()
    for segInd in range(len(outgcn_numpy)):
        segments_copy[original_labels == segInd] = outgcn_numpy[segInd]
    return segments_copy
